{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "Fine-tuning LayoutLMv2ForTokenClassification on CORD.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "mount_file_id": "1O-rUeZpgr0o6vqd5m2VpfVoYJBrCtT7j",
      "authorship_tag": "ABX9TyNEjJ6bJ2PgIyOfsIcQELwW",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "128dbaf10f724ba7bdf57da3910e7128": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_d6778ed2a24840eb9cb2a85df28531e6",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_4cd0bd32d51541879a2f0a636638b19d",
              "IPY_MODEL_c237fb40ccff49dba279b91d1923d0e2",
              "IPY_MODEL_43c58c208c904fed85d14271bec6cd33"
            ]
          }
        },
        "d6778ed2a24840eb9cb2a85df28531e6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "4cd0bd32d51541879a2f0a636638b19d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_43ebe3095db34987a870d6275c7cff8d",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "100%",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_525b1054deb84a68b763ac1240304d8c"
          }
        },
        "c237fb40ccff49dba279b91d1923d0e2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_3655ce008da64cdeae1f613723c5c50f",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 400,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 400,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_cfdf96fd74e04945b0fe1cade6143d80"
          }
        },
        "43c58c208c904fed85d14271bec6cd33": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_f1a6df2cc9d84611bc44060a02910fc5",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 400/400 [11:59&lt;00:00,  2.04s/it]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_4c9b7c05398e4346bccd6e163a9ab7f1"
          }
        },
        "43ebe3095db34987a870d6275c7cff8d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "525b1054deb84a68b763ac1240304d8c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "3655ce008da64cdeae1f613723c5c50f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "cfdf96fd74e04945b0fe1cade6143d80": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f1a6df2cc9d84611bc44060a02910fc5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "4c9b7c05398e4346bccd6e163a9ab7f1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "32eadf9c95fb40a5b96799a98194bd62": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_1ecb1402abe84ac7a837fa2fc03bd55f",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_47c021d9c0924f93a97ee063bceb736b",
              "IPY_MODEL_1d89ea05fd2f49a89013d52dc547a230",
              "IPY_MODEL_899c9589cd4d445994adbfea77cbab84"
            ]
          }
        },
        "1ecb1402abe84ac7a837fa2fc03bd55f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "47c021d9c0924f93a97ee063bceb736b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_d783e756b2fe491ca2189d7b6fc1f4b2",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "100%",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_822c074ff17c4dc596af383c1553ef8c"
          }
        },
        "1d89ea05fd2f49a89013d52dc547a230": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_801ea8bc95904e058c99d574597018fb",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 400,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 400,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_64c098723c024c9e9a38977fb3b8c59f"
          }
        },
        "899c9589cd4d445994adbfea77cbab84": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_7c23545e83ac4b9db7cc5d1f599b2456",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 400/400 [03:49&lt;00:00,  1.85it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_942af94819dd4c378e69558a6b49d5d9"
          }
        },
        "d783e756b2fe491ca2189d7b6fc1f4b2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "822c074ff17c4dc596af383c1553ef8c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "801ea8bc95904e058c99d574597018fb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "64c098723c024c9e9a38977fb3b8c59f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "7c23545e83ac4b9db7cc5d1f599b2456": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "942af94819dd4c378e69558a6b49d5d9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "587359bd11274d338aab9e7f56e94488": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_da1e33a04c584f6998afe2aa0972c8d6",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_bbd4380e27264c83b948da7ae52c2c59",
              "IPY_MODEL_54f2a1073ca04662b85f12eefd3b527b",
              "IPY_MODEL_8812c60837aa42f792807507dc4ff477"
            ]
          }
        },
        "da1e33a04c584f6998afe2aa0972c8d6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "bbd4380e27264c83b948da7ae52c2c59": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_847a27c4e2f84db0af7e4f58e3bedfc7",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "100%",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_7010e46c84794c55ac9c9d8aae59a9ba"
          }
        },
        "54f2a1073ca04662b85f12eefd3b527b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_e6697b897bbf4faf8067a1466b0337ea",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 400,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 400,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_b839fadf6ad144078a8c7b8acec74933"
          }
        },
        "8812c60837aa42f792807507dc4ff477": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_653a75472dca4ca285f91926cb044e90",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 400/400 [03:51&lt;00:00,  1.60it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_93a08408d1da4e3680c20f631a2bf4e9"
          }
        },
        "847a27c4e2f84db0af7e4f58e3bedfc7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "7010e46c84794c55ac9c9d8aae59a9ba": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "e6697b897bbf4faf8067a1466b0337ea": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "b839fadf6ad144078a8c7b8acec74933": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "653a75472dca4ca285f91926cb044e90": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "93a08408d1da4e3680c20f631a2bf4e9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "36cce2a65b8a4f6b983ae0671a06d116": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_cb0ea4ec80b841baadf53fc703f99557",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_8deed5f919c04bcab95316ad597e7a2b",
              "IPY_MODEL_78104f8f83cc4054a1067e8d2a6a4e4e",
              "IPY_MODEL_bed2b019452c496eb1b1cf74d1a280f3"
            ]
          }
        },
        "cb0ea4ec80b841baadf53fc703f99557": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "8deed5f919c04bcab95316ad597e7a2b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_2ddd879ca3604495954fcc9c83cd3b37",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "100%",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_cf40611daddf458b932390aa47fbe124"
          }
        },
        "78104f8f83cc4054a1067e8d2a6a4e4e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_052a590591af4ea4a3d6a7f3ab3a17de",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 400,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 400,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_f6ccafb933e6452a931458c651ed565a"
          }
        },
        "bed2b019452c496eb1b1cf74d1a280f3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_e647d3a440994156b07f3271244a980c",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 400/400 [03:51&lt;00:00,  1.89it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_f49d60f8f1de48b5b3b83dc0fa7f4911"
          }
        },
        "2ddd879ca3604495954fcc9c83cd3b37": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "cf40611daddf458b932390aa47fbe124": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "052a590591af4ea4a3d6a7f3ab3a17de": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "f6ccafb933e6452a931458c651ed565a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "e647d3a440994156b07f3271244a980c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "f49d60f8f1de48b5b3b83dc0fa7f4911": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "4bd5fc6d3bb64279b8594b0cffbb5d59": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_97241634272d48afa7e5f4eef7b5311f",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_ab33f31ca74740dd8ca9fc98cc5bdb25",
              "IPY_MODEL_24a78c17aec14d308080ac07d36865dd",
              "IPY_MODEL_0aee29ed9f394ccf81c126466f78b7c2"
            ]
          }
        },
        "97241634272d48afa7e5f4eef7b5311f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "ab33f31ca74740dd8ca9fc98cc5bdb25": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_e90a122a051b44fea59b3c7da4754f6d",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "Evaluating: 100%",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_584c91d192ce452698d067c3998adde8"
          }
        },
        "24a78c17aec14d308080ac07d36865dd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_9a16c3df03304d39a028dce62b5187b1",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 50,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 50,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_2e6f420dfdc64372b0fdc497cd3543e3"
          }
        },
        "0aee29ed9f394ccf81c126466f78b7c2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_01b68a67df974c6991b9716e211f32f1",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 50/50 [01:48&lt;00:00,  2.19s/it]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_1e85a9c975c244768006301dc00c7a31"
          }
        },
        "e90a122a051b44fea59b3c7da4754f6d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "584c91d192ce452698d067c3998adde8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "9a16c3df03304d39a028dce62b5187b1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "2e6f420dfdc64372b0fdc497cd3543e3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "01b68a67df974c6991b9716e211f32f1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "1e85a9c975c244768006301dc00c7a31": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLMv2/CORD/Fine_tuning_LayoutLMv2ForTokenClassification_on_CORD.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z7E3hFdIQBYu"
      },
      "source": [
        "In this notebook, we are going to fine-tune `LayoutLMv2ForTokenClassification` on the [CORD](https://github.com/clovaai/cord) dataset. The goal for the model is to label words appearing in scanned documents (namely, receipts) appropriately. This task is treated as a NER problem (sequence labeling). However, compared to BERT, LayoutLMv2 also incorporates visual and layout information about the tokens when encoding them into vectors. This makes the LayoutLMv2 model very powerful for document understanding tasks.\n",
        "\n",
        "LayoutLMv2 is itself an upgrade of LayoutLM. The main novelty of LayoutLMv2 is that it also pre-trains visual embeddings, whereas the original LayoutLM only adds visual embeddings during fine-tuning.\n",
        "\n",
        "* Paper: https://arxiv.org/abs/2012.14740\n",
        "* Original repo: https://github.com/microsoft/unilm/tree/master/layoutlmv2\n",
        "\n",
        "NOTES: \n",
        "\n",
        "* you first need to prepare the CORD dataset for LayoutLMv2. For that, check out the notebook \"Prepare CORD for LayoutLMv2\".\n",
        "* this notebook is heavily inspired by [this Github repository](https://github.com/omarsou/layoutlm_CORD), which fine-tunes both BERT and LayoutLM (v1) on the CORD dataset.\n",
        "\n",
        "\n",
        "\n",
        "## Install dependencies\n",
        "\n",
        "First, we install the required libraries:\n",
        "* Transformers (for the LayoutLMv2 model)\n",
        "* Datasets (for data preprocessing)\n",
        "* Seqeval (for metrics)\n",
        "* Detectron2 (which LayoutLMv2 requires for its visual backbone).\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vL18fbJlmXn8",
        "outputId": "e2d952fd-71c6-460a-cc1e-17a292cfe3f8"
      },
      "source": [
        "!rm -r transformers\n",
        "!git clone -b modeling_layoutlmv2_v2 https://github.com/NielsRogge/transformers.git\n",
        "!cd tranformers\n",
        "!pip install -q ./transformers "
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'transformers'...\n",
            "remote: Enumerating objects: 83407, done.\u001b[K\n",
            "remote: Counting objects: 100% (1845/1845), done.\u001b[K\n",
            "remote: Compressing objects: 100% (529/529), done.\u001b[K\n",
            "remote: Total 83407 (delta 1195), reused 1697 (delta 1099), pack-reused 81562\u001b[K\n",
            "Receiving objects: 100% (83407/83407), 63.59 MiB | 27.13 MiB/s, done.\n",
            "Resolving deltas: 100% (59481/59481), done.\n",
            "/bin/bash: line 0: cd: tranformers: No such file or directory\n",
            "\u001b[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.\n",
            "   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.\u001b[0m\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Byx9C8bmGVLe"
      },
      "source": [
        "!pip install -q datasets seqeval"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "B1CNb7_MP2Z-",
        "outputId": "840e7367-be87-4f91-fe85-87f10a9dacd4"
      },
      "source": [
        "!pip install pyyaml==5.1\n",
        "# workaround: install old version of pytorch since detectron2 hasn't released packages for pytorch 1.9 (issue: https://github.com/facebookresearch/detectron2/issues/3158)\n",
        "!pip install torch==1.8.0+cu101 torchvision==0.9.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html\n",
        "\n",
        "# install detectron2 that matches pytorch 1.8\n",
        "# See https://detectron2.readthedocs.io/tutorials/install.html for instructions\n",
        "!pip install -q detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html\n",
        "# exit(0)  # After installation, you need to \"restart runtime\" in Colab. This line can also restart runtime"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pyyaml==5.1 in /usr/local/lib/python3.7/dist-packages (5.1)\n",
            "Looking in links: https://download.pytorch.org/whl/torch_stable.html\n",
            "Requirement already satisfied: torch==1.8.0+cu101 in /usr/local/lib/python3.7/dist-packages (1.8.0+cu101)\n",
            "Requirement already satisfied: torchvision==0.9.0+cu101 in /usr/local/lib/python3.7/dist-packages (0.9.0+cu101)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.8.0+cu101) (3.7.4.3)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch==1.8.0+cu101) (1.19.5)\n",
            "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from torchvision==0.9.0+cu101) (7.1.2)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iOxLV0JnQJWJ"
      },
      "source": [
        "## Prepare the data\n",
        "\n",
        "First, let's read in the annotations which we prepared in the other notebook. These contain the word-level annotations (words, labels, normalized bounding boxes)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f4ZyOsDcEGQL"
      },
      "source": [
        "import pandas as pd\n",
        "\n",
        "train = pd.read_pickle('/content/drive/MyDrive/LayoutLMv2/Tutorial notebooks/CORD/CORD_layoutlmv2_format/train.pkl')\n",
        "val = pd.read_pickle('/content/drive/MyDrive/LayoutLMv2/Tutorial notebooks/CORD/CORD_layoutlmv2_format/dev.pkl')\n",
        "test = pd.read_pickle('/content/drive/MyDrive/LayoutLMv2/Tutorial notebooks/CORD/CORD_layoutlmv2_format/test.pkl')"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1cXWGK_yOuOf"
      },
      "source": [
        "Let's define a list of all unique labels. For that, let's first count the number of occurrences for each label:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1RJLFClPIYjt",
        "outputId": "e033fd1f-0fcc-4a8c-9287-838cdcf1e57e"
      },
      "source": [
        "from collections import Counter\n",
        "\n",
        "all_labels = [item for sublist in train[1] for item in sublist] + [item for sublist in val[1] for item in sublist] + [item for sublist in test[1] for item in sublist]\n",
        "Counter(all_labels)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Counter({'menu.cnt': 2429,\n",
              "         'menu.discountprice': 403,\n",
              "         'menu.etc': 19,\n",
              "         'menu.itemsubtotal': 7,\n",
              "         'menu.nm': 6597,\n",
              "         'menu.num': 109,\n",
              "         'menu.price': 2585,\n",
              "         'menu.sub_cnt': 189,\n",
              "         'menu.sub_etc': 9,\n",
              "         'menu.sub_nm': 822,\n",
              "         'menu.sub_price': 160,\n",
              "         'menu.sub_unitprice': 14,\n",
              "         'menu.unitprice': 750,\n",
              "         'menu.vatyn': 9,\n",
              "         'sub_total.discount_price': 191,\n",
              "         'sub_total.etc': 283,\n",
              "         'sub_total.othersvc_price': 6,\n",
              "         'sub_total.service_price': 353,\n",
              "         'sub_total.subtotal_price': 1482,\n",
              "         'sub_total.tax_price': 1283,\n",
              "         'total.cashprice': 1393,\n",
              "         'total.changeprice': 1297,\n",
              "         'total.creditcardprice': 410,\n",
              "         'total.emoneyprice': 129,\n",
              "         'total.menuqty_cnt': 630,\n",
              "         'total.menutype_cnt': 130,\n",
              "         'total.total_etc': 89,\n",
              "         'total.total_price': 2120,\n",
              "         'void_menu.nm': 3,\n",
              "         'void_menu.price': 1})"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h5jvOQR2Ivs5"
      },
      "source": [
        "As we can see, there are some labels that contain very few examples. Let's replace them by the \"neutral\" label \"O\" (which stands for \"Outside\")."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7UzRGzqlIx22"
      },
      "source": [
        "replacing_labels = {'menu.etc': 'O', 'mneu.itemsubtotal': 'O', 'menu.sub_etc': 'O', 'menu.sub_unitprice': 'O', 'menu.vatyn': 'O',\n",
        "                  'void_menu.nm': 'O', 'void_menu.price': 'O', 'sub_total.othersvc_price': 'O'}"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ldWhTh1wIyX-"
      },
      "source": [
        "def replace_elem(elem):\n",
        "  try:\n",
        "    return replacing_labels[elem]\n",
        "  except KeyError:\n",
        "    return elem\n",
        "def replace_list(ls):\n",
        "  return [replace_elem(elem) for elem in ls]\n",
        "train[1] = [replace_list(ls) for ls in train[1]]\n",
        "val[1] = [replace_list(ls) for ls in val[1]]\n",
        "test[1] = [replace_list(ls) for ls in test[1]]"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YIN8ptH4JD3T",
        "outputId": "f14b6cab-383f-4dc7-ddb2-a27f33b50858"
      },
      "source": [
        "all_labels = [item for sublist in train[1] for item in sublist] + [item for sublist in val[1] for item in sublist] + [item for sublist in test[1] for item in sublist]\n",
        "Counter(all_labels)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Counter({'O': 61,\n",
              "         'menu.cnt': 2429,\n",
              "         'menu.discountprice': 403,\n",
              "         'menu.itemsubtotal': 7,\n",
              "         'menu.nm': 6597,\n",
              "         'menu.num': 109,\n",
              "         'menu.price': 2585,\n",
              "         'menu.sub_cnt': 189,\n",
              "         'menu.sub_nm': 822,\n",
              "         'menu.sub_price': 160,\n",
              "         'menu.unitprice': 750,\n",
              "         'sub_total.discount_price': 191,\n",
              "         'sub_total.etc': 283,\n",
              "         'sub_total.service_price': 353,\n",
              "         'sub_total.subtotal_price': 1482,\n",
              "         'sub_total.tax_price': 1283,\n",
              "         'total.cashprice': 1393,\n",
              "         'total.changeprice': 1297,\n",
              "         'total.creditcardprice': 410,\n",
              "         'total.emoneyprice': 129,\n",
              "         'total.menuqty_cnt': 630,\n",
              "         'total.menutype_cnt': 130,\n",
              "         'total.total_etc': 89,\n",
              "         'total.total_price': 2120})"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "shn-dMRbJGuu"
      },
      "source": [
        "Now we have to save all the unique labels in a list."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EyK3lnaCJHWn",
        "outputId": "26bc2648-5043-445f-d49c-719360a55364"
      },
      "source": [
        "labels = list(set(all_labels))\n",
        "print(labels)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['total.creditcardprice', 'menu.num', 'total.total_etc', 'menu.cnt', 'menu.sub_cnt', 'total.menutype_cnt', 'total.menuqty_cnt', 'menu.discountprice', 'menu.sub_nm', 'total.changeprice', 'menu.sub_price', 'sub_total.service_price', 'menu.itemsubtotal', 'menu.unitprice', 'sub_total.subtotal_price', 'O', 'sub_total.etc', 'sub_total.tax_price', 'sub_total.discount_price', 'menu.nm', 'total.emoneyprice', 'menu.price', 'total.cashprice', 'total.total_price']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lvYHvmk1J2a2",
        "outputId": "3fbced2d-43a7-42fd-97e6-dde358e56266"
      },
      "source": [
        "label2id = {label: idx for idx, label in enumerate(labels)}\n",
        "id2label = {idx: label for idx, label in enumerate(labels)}\n",
        "print(label2id)\n",
        "print(id2label)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "{'total.creditcardprice': 0, 'menu.num': 1, 'total.total_etc': 2, 'menu.cnt': 3, 'menu.sub_cnt': 4, 'total.menutype_cnt': 5, 'total.menuqty_cnt': 6, 'menu.discountprice': 7, 'menu.sub_nm': 8, 'total.changeprice': 9, 'menu.sub_price': 10, 'sub_total.service_price': 11, 'menu.itemsubtotal': 12, 'menu.unitprice': 13, 'sub_total.subtotal_price': 14, 'O': 15, 'sub_total.etc': 16, 'sub_total.tax_price': 17, 'sub_total.discount_price': 18, 'menu.nm': 19, 'total.emoneyprice': 20, 'menu.price': 21, 'total.cashprice': 22, 'total.total_price': 23}\n",
            "{0: 'total.creditcardprice', 1: 'menu.num', 2: 'total.total_etc', 3: 'menu.cnt', 4: 'menu.sub_cnt', 5: 'total.menutype_cnt', 6: 'total.menuqty_cnt', 7: 'menu.discountprice', 8: 'menu.sub_nm', 9: 'total.changeprice', 10: 'menu.sub_price', 11: 'sub_total.service_price', 12: 'menu.itemsubtotal', 13: 'menu.unitprice', 14: 'sub_total.subtotal_price', 15: 'O', 16: 'sub_total.etc', 17: 'sub_total.tax_price', 18: 'sub_total.discount_price', 19: 'menu.nm', 20: 'total.emoneyprice', 21: 'menu.price', 22: 'total.cashprice', 23: 'total.total_price'}\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gq5CmIRZUy9O"
      },
      "source": [
        "from os import listdir\n",
        "from torch.utils.data import Dataset\n",
        "import torch\n",
        "from PIL import Image\n",
        "\n",
        "class CORDDataset(Dataset):\n",
        "    \"\"\"CORD dataset.\"\"\"\n",
        "\n",
        "    def __init__(self, annotations, image_dir, processor=None, max_length=512):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            annotations (List[List]): List of lists containing the word-level annotations (words, labels, boxes).\n",
        "            image_dir (string): Directory with all the document images.\n",
        "            processor (LayoutLMv2Processor): Processor to prepare the text + image.\n",
        "        \"\"\"\n",
        "        self.words, self.labels, self.boxes = annotations\n",
        "        self.image_dir = image_dir\n",
        "        self.image_file_names = [f for f in listdir(image_dir)]\n",
        "        self.processor = processor\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.image_file_names)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        # first, take an image\n",
        "        item = self.image_file_names[idx]\n",
        "        image = Image.open(self.image_dir + item).convert(\"RGB\")\n",
        "\n",
        "        # get word-level annotations \n",
        "        words = self.words[idx]\n",
        "        boxes = self.boxes[idx]\n",
        "        word_labels = self.labels[idx]\n",
        "\n",
        "        assert len(words) == len(boxes) == len(word_labels)\n",
        "        \n",
        "        word_labels = [label2id[label] for label in word_labels]\n",
        "        # use processor to prepare everything\n",
        "        encoded_inputs = self.processor(image, words, boxes=boxes, word_labels=word_labels, \n",
        "                                        padding=\"max_length\", truncation=True, \n",
        "                                        return_tensors=\"pt\")\n",
        "        \n",
        "        # remove batch dimension\n",
        "        for k,v in encoded_inputs.items():\n",
        "          encoded_inputs[k] = v.squeeze()\n",
        "\n",
        "        assert encoded_inputs.input_ids.shape == torch.Size([512])\n",
        "        assert encoded_inputs.attention_mask.shape == torch.Size([512])\n",
        "        assert encoded_inputs.token_type_ids.shape == torch.Size([512])\n",
        "        assert encoded_inputs.bbox.shape == torch.Size([512, 4])\n",
        "        assert encoded_inputs.image.shape == torch.Size([3, 224, 224])\n",
        "        assert encoded_inputs.labels.shape == torch.Size([512]) \n",
        "      \n",
        "        return encoded_inputs"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "izA9k5e-egyy"
      },
      "source": [
        "from transformers import LayoutLMv2Processor\n",
        "\n",
        "processor = LayoutLMv2Processor.from_pretrained(\"microsoft/layoutlmv2-base-uncased\", revision=\"no_ocr\")\n",
        "\n",
        "train_dataset = CORDDataset(annotations=train,\n",
        "                            image_dir='/content/drive/MyDrive/LayoutLMv2/Tutorial notebooks/CORD/CORD/train/image/', \n",
        "                            processor=processor)\n",
        "val_dataset = CORDDataset(annotations=val,\n",
        "                            image_dir='/content/drive/MyDrive/LayoutLMv2/Tutorial notebooks/CORD/CORD/dev/image/', \n",
        "                            processor=processor)\n",
        "test_dataset = CORDDataset(annotations=test,\n",
        "                            image_dir='/content/drive/MyDrive/LayoutLMv2/Tutorial notebooks/CORD/CORD/test/image/', \n",
        "                            processor=processor)"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9qGt51RyetD2"
      },
      "source": [
        "Let's verify an example:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pHg2--KMeuRN",
        "outputId": "8af0b314-5c65-422f-a43e-6e9fa09780aa"
      },
      "source": [
        "encoding = train_dataset[0]\n",
        "encoding.keys()"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'labels', 'image'])"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6bpvNOncff8J",
        "outputId": "e3c2baa7-7ac0-424e-fdae-0bfb21d6cefe"
      },
      "source": [
        "for k,v in encoding.items():\n",
        "  print(k, v.shape)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "input_ids torch.Size([512])\n",
            "token_type_ids torch.Size([512])\n",
            "attention_mask torch.Size([512])\n",
            "bbox torch.Size([512, 4])\n",
            "labels torch.Size([512])\n",
            "image torch.Size([3, 224, 224])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Uxnt2KXBMSxn",
        "outputId": "13702989-2f3b-4b12-c310-f8702ea51e57"
      },
      "source": [
        "print(processor.tokenizer.decode(encoding['input_ids']))"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[CLS] tebu lemon 1 22. 000 total 22. 000 cash 22. 000 1 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3Ct1Ut9trlbn",
        "outputId": "0a41d2b9-7093-434c-ae0d-ee20770cc016"
      },
      "source": [
        "train[0][0]"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['Tebu', 'Lemon', '1', '22.000', 'Total', '22.000', 'CASH', '22.000', '1']"
            ]
          },
          "metadata": {},
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vFzAqOpUrojk",
        "outputId": "27a66012-27ba-49f5-cb0b-758717a3db64"
      },
      "source": [
        "train[1][0]"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['menu.nm',\n",
              " 'menu.nm',\n",
              " 'menu.cnt',\n",
              " 'menu.price',\n",
              " 'total.total_price',\n",
              " 'total.total_price',\n",
              " 'total.cashprice',\n",
              " 'total.cashprice',\n",
              " 'total.menuqty_cnt']"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "laq33rnBri4T",
        "outputId": "41da7ca1-5999-4edb-df31-9c6c438fd7ed"
      },
      "source": [
        "[id2label[label] for label in encoding['labels'].tolist() if label != -100]"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['menu.nm',\n",
              " 'menu.nm',\n",
              " 'menu.cnt',\n",
              " 'menu.price',\n",
              " 'total.total_price',\n",
              " 'total.total_price',\n",
              " 'total.cashprice',\n",
              " 'total.cashprice',\n",
              " 'total.menuqty_cnt']"
            ]
          },
          "metadata": {},
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qkCH2Q-_MnRV",
        "outputId": "1b5bce0b-c505-463e-fed0-90c71614f120"
      },
      "source": [
        "for id, label in zip(encoding['input_ids'][:30], encoding['labels'][:30]):\n",
        "  print(processor.tokenizer.decode([id]), label.item())"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[CLS] -100\n",
            "te 19\n",
            "##bu -100\n",
            "lemon 19\n",
            "1 3\n",
            "22 21\n",
            ". -100\n",
            "000 -100\n",
            "total 23\n",
            "22 23\n",
            ". -100\n",
            "000 -100\n",
            "cash 22\n",
            "22 22\n",
            ". -100\n",
            "000 -100\n",
            "1 6\n",
            "[SEP] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n",
            "[PAD] -100\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PAQJLuPPf27V"
      },
      "source": [
        "Next, we create corresponding dataloaders."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XI-eM0m4fmfF"
      },
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)\n",
        "val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True)\n",
        "test_dataloader = DataLoader(test_dataset, batch_size=2)"
      ],
      "execution_count": 20,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hHgKGixHgIIC"
      },
      "source": [
        "## Train the model\n",
        "\n",
        "Let's train the model using native PyTorch. We use the AdamW optimizer with learning rate = 5e-5 (this is a good default value when fine-tuning Transformer-based models).\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 590,
          "referenced_widgets": [
            "128dbaf10f724ba7bdf57da3910e7128",
            "d6778ed2a24840eb9cb2a85df28531e6",
            "4cd0bd32d51541879a2f0a636638b19d",
            "c237fb40ccff49dba279b91d1923d0e2",
            "43c58c208c904fed85d14271bec6cd33",
            "43ebe3095db34987a870d6275c7cff8d",
            "525b1054deb84a68b763ac1240304d8c",
            "3655ce008da64cdeae1f613723c5c50f",
            "cfdf96fd74e04945b0fe1cade6143d80",
            "f1a6df2cc9d84611bc44060a02910fc5",
            "4c9b7c05398e4346bccd6e163a9ab7f1",
            "32eadf9c95fb40a5b96799a98194bd62",
            "1ecb1402abe84ac7a837fa2fc03bd55f",
            "47c021d9c0924f93a97ee063bceb736b",
            "1d89ea05fd2f49a89013d52dc547a230",
            "899c9589cd4d445994adbfea77cbab84",
            "d783e756b2fe491ca2189d7b6fc1f4b2",
            "822c074ff17c4dc596af383c1553ef8c",
            "801ea8bc95904e058c99d574597018fb",
            "64c098723c024c9e9a38977fb3b8c59f",
            "7c23545e83ac4b9db7cc5d1f599b2456",
            "942af94819dd4c378e69558a6b49d5d9",
            "587359bd11274d338aab9e7f56e94488",
            "da1e33a04c584f6998afe2aa0972c8d6",
            "bbd4380e27264c83b948da7ae52c2c59",
            "54f2a1073ca04662b85f12eefd3b527b",
            "8812c60837aa42f792807507dc4ff477",
            "847a27c4e2f84db0af7e4f58e3bedfc7",
            "7010e46c84794c55ac9c9d8aae59a9ba",
            "e6697b897bbf4faf8067a1466b0337ea",
            "b839fadf6ad144078a8c7b8acec74933",
            "653a75472dca4ca285f91926cb044e90",
            "93a08408d1da4e3680c20f631a2bf4e9",
            "36cce2a65b8a4f6b983ae0671a06d116",
            "cb0ea4ec80b841baadf53fc703f99557",
            "8deed5f919c04bcab95316ad597e7a2b",
            "78104f8f83cc4054a1067e8d2a6a4e4e",
            "bed2b019452c496eb1b1cf74d1a280f3",
            "2ddd879ca3604495954fcc9c83cd3b37",
            "cf40611daddf458b932390aa47fbe124",
            "052a590591af4ea4a3d6a7f3ab3a17de",
            "f6ccafb933e6452a931458c651ed565a",
            "e647d3a440994156b07f3271244a980c",
            "f49d60f8f1de48b5b3b83dc0fa7f4911"
          ]
        },
        "id": "NxahVHZ0NKq7",
        "outputId": "649a7b81-800c-4436-abf8-df11fb00981f"
      },
      "source": [
        "from transformers import LayoutLMv2ForTokenClassification, AdamW\n",
        "import torch\n",
        "from tqdm.notebook import tqdm\n",
        "\n",
        "model = LayoutLMv2ForTokenClassification.from_pretrained('microsoft/layoutlmv2-base-uncased',\n",
        "                                                                      num_labels=len(labels))\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "model.to(device)\n",
        "optimizer = AdamW(model.parameters(), lr=5e-5)\n",
        "\n",
        "global_step = 0\n",
        "num_train_epochs = 4\n",
        "\n",
        "#put the model in training mode\n",
        "model.train() \n",
        "for epoch in range(num_train_epochs):  \n",
        "   print(\"Epoch:\", epoch)\n",
        "   for batch in tqdm(train_dataloader):\n",
        "        # get the inputs;\n",
        "        input_ids = batch['input_ids'].to(device)\n",
        "        bbox = batch['bbox'].to(device)\n",
        "        image = batch['image'].to(device)\n",
        "        attention_mask = batch['attention_mask'].to(device)\n",
        "        token_type_ids = batch['token_type_ids'].to(device)\n",
        "        labels = batch['labels'].to(device)\n",
        "\n",
        "        # zero the parameter gradients\n",
        "        optimizer.zero_grad()\n",
        "        \n",
        "        # forward + backward + optimize\n",
        "        outputs = model(input_ids=input_ids,\n",
        "                        bbox=bbox,\n",
        "                        image=image,\n",
        "                        attention_mask=attention_mask,\n",
        "                        token_type_ids=token_type_ids,\n",
        "                        labels=labels) \n",
        "        loss = outputs.loss\n",
        "        \n",
        "        # print loss every 100 steps\n",
        "        if global_step % 100 == 0:\n",
        "          print(f\"Loss after {global_step} steps: {loss.item()}\")\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        global_step += 1\n",
        "\n",
        "model.save_pretrained(\"/content/drive/MyDrive/LayoutLMv2/Tutorial notebooks/CORD/Checkpoints\")"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Some weights of the model checkpoint at microsoft/layoutlmv2-base-uncased were not used when initializing LayoutLMv2ForTokenClassification: ['layoutlmv2.visual.backbone.bottom_up.res4.9.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.15.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.7.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.19.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.8.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.3.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.7.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.13.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.13.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.6.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.11.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.3.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.stem.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.22.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.10.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.14.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.9.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.11.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.12.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.6.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.20.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.4.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.13.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.8.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.12.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.8.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.16.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.5.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.16.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.3.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.19.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.15.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.20.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.7.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.3.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.3.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.14.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.14.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.22.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.4.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.20.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.19.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.22.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.10.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.16.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.3.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.5.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.10.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.9.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.4.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.15.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.6.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.11.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.5.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.12.conv2.norm.num_batches_tracked']\n",
            "- This IS expected if you are initializing LayoutLMv2ForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing LayoutLMv2ForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of LayoutLMv2ForTokenClassification were not initialized from the model checkpoint at microsoft/layoutlmv2-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Epoch: 0\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "128dbaf10f724ba7bdf57da3910e7128",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "  0%|          | 0/400 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "text": [
            "Loss after 0 steps: 3.1703407764434814\n",
            "Loss after 100 steps: 2.3077478408813477\n",
            "Loss after 200 steps: 1.6867226362228394\n",
            "Loss after 300 steps: 1.0950108766555786\n",
            "Epoch: 1\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "32eadf9c95fb40a5b96799a98194bd62",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "  0%|          | 0/400 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "text": [
            "Loss after 400 steps: 2.3882484436035156\n",
            "Loss after 500 steps: 0.8808697462081909\n",
            "Loss after 600 steps: 0.9903069734573364\n",
            "Loss after 700 steps: 0.5121206641197205\n",
            "Epoch: 2\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "587359bd11274d338aab9e7f56e94488",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "  0%|          | 0/400 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "text": [
            "Loss after 800 steps: 0.23476076126098633\n",
            "Loss after 900 steps: 0.2449778914451599\n",
            "Loss after 1000 steps: 0.3742366433143616\n",
            "Loss after 1100 steps: 0.2236849069595337\n",
            "Epoch: 3\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "36cce2a65b8a4f6b983ae0671a06d116",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "  0%|          | 0/400 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "text": [
            "Loss after 1200 steps: 0.8207438588142395\n",
            "Loss after 1300 steps: 0.31655505299568176\n",
            "Loss after 1400 steps: 0.4179667830467224\n",
            "Loss after 1500 steps: 0.04862891137599945\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jDREDNOhPv0C"
      },
      "source": [
        "## Evaluation\n",
        "\n",
        "Let's evaluate the model on the test set. First, let's do a sanity check on the first example of the test set."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 171
        },
        "id": "8BGN3Yo6geJT",
        "outputId": "fb987c48-6df4-447b-b946-b9b6273e4220"
      },
      "source": [
        "encoding = test_dataset[0]\n",
        "processor.tokenizer.decode(encoding['input_ids'])"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "\"[CLS] rp. goblin's mace 1 25, 000 mozarella hot dog 2 38, 000 chili pepper croquette 1 14, 000 cheese croquette 1 14, 000 plastik amook 1 0 plastik putih take away 1 0 subtotal 91, 000 discount ( 0 ) total 91, 000 debit 91, 000 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\""
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-xkVMaHXyOVO",
        "outputId": "6c2fed19-73ed-4cac-f8be-a9289d3ea5a5"
      },
      "source": [
        "ground_truth_labels = [id2label[label] for label in encoding['labels'].squeeze().tolist() if label != -100]\n",
        "print(ground_truth_labels)"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['total.total_price', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'sub_total.subtotal_price', 'sub_total.subtotal_price', 'sub_total.discount_price', 'sub_total.discount_price', 'total.total_price', 'total.total_price', 'total.creditcardprice', 'total.creditcardprice']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8ZL0JmyxsJEV"
      },
      "source": [
        "for k,v in encoding.items():\n",
        "  encoding[k] = v.unsqueeze(0).to(device)\n",
        "\n",
        "model.eval()\n",
        "# forward pass\n",
        "outputs = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'],\n",
        "                token_type_ids=encoding['token_type_ids'], bbox=encoding['bbox'],\n",
        "                image=encoding['image'])"
      ],
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wDN8cNKG1PLQ",
        "outputId": "5536968d-1479-4993-eb49-07128765c0c2"
      },
      "source": [
        "prediction_indices = outputs.logits.argmax(-1).squeeze().tolist()\n",
        "print(prediction_indices)"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[19, 14, 14, 14, 19, 19, 19, 19, 3, 21, 21, 21, 19, 19, 19, 19, 19, 3, 21, 21, 21, 19, 19, 19, 19, 19, 3, 21, 21, 21, 19, 19, 19, 19, 3, 21, 21, 21, 19, 19, 19, 19, 19, 3, 21, 19, 19, 19, 19, 19, 19, 19, 3, 21, 14, 14, 14, 14, 14, 14, 18, 18, 18, 18, 23, 23, 23, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 19, 19, 19, 0, 0, 0, 19, 19, 19, 19, 19, 19, 0, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 19, 0, 0, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 19, 0, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 0, 0, 19, 19, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 19, 0, 19, 0, 0, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 19, 19, 0, 19, 0, 19, 19, 19, 19, 19, 19, 19, 0, 19, 19, 0, 19, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Df2EqPPLyafz",
        "outputId": "67c5f11f-1c03-4bbe-aa2d-e1fd97464a1d"
      },
      "source": [
        "prediction_indices = outputs.logits.argmax(-1).squeeze().tolist()\n",
        "predictions = [id2label[label] for gt, label in zip(encoding['labels'].squeeze().tolist(), prediction_indices) if gt != -100]\n",
        "print(predictions)"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['sub_total.subtotal_price', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'sub_total.subtotal_price', 'sub_total.subtotal_price', 'sub_total.discount_price', 'sub_total.discount_price', 'total.total_price', 'total.total_price', 'total.creditcardprice', 'total.creditcardprice']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "4bd5fc6d3bb64279b8594b0cffbb5d59",
            "97241634272d48afa7e5f4eef7b5311f",
            "ab33f31ca74740dd8ca9fc98cc5bdb25",
            "24a78c17aec14d308080ac07d36865dd",
            "0aee29ed9f394ccf81c126466f78b7c2",
            "e90a122a051b44fea59b3c7da4754f6d",
            "584c91d192ce452698d067c3998adde8",
            "9a16c3df03304d39a028dce62b5187b1",
            "2e6f420dfdc64372b0fdc497cd3543e3",
            "01b68a67df974c6991b9716e211f32f1",
            "1e85a9c975c244768006301dc00c7a31"
          ]
        },
        "id": "-QGUtQuIacyF",
        "outputId": "9e08df7e-36af-4e07-ee5e-9bba06dd76c6"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "preds_val = None\n",
        "out_label_ids = None\n",
        "\n",
        "# put model in evaluation mode\n",
        "model.eval()\n",
        "for batch in tqdm(test_dataloader, desc=\"Evaluating\"):\n",
        "    with torch.no_grad():\n",
        "        input_ids = batch['input_ids'].to(device)\n",
        "        bbox = batch['bbox'].to(device)\n",
        "        image = batch['image'].to(device)\n",
        "        attention_mask = batch['attention_mask'].to(device)\n",
        "        token_type_ids = batch['token_type_ids'].to(device)\n",
        "        labels = batch['labels'].to(device)\n",
        "\n",
        "        # forward pass\n",
        "        outputs = model(input_ids=input_ids, bbox=bbox, image=image, attention_mask=attention_mask, \n",
        "                        token_type_ids=token_type_ids, labels=labels)\n",
        "        \n",
        "        if preds_val is None:\n",
        "          preds_val = outputs.logits.detach().cpu().numpy()\n",
        "          out_label_ids = batch[\"labels\"].detach().cpu().numpy()\n",
        "        else:\n",
        "          preds_val = np.append(preds_val, outputs.logits.detach().cpu().numpy(), axis=0)\n",
        "          out_label_ids = np.append(\n",
        "              out_label_ids, batch[\"labels\"].detach().cpu().numpy(), axis=0\n",
        "          )"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4bd5fc6d3bb64279b8594b0cffbb5d59",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "Evaluating:   0%|          | 0/50 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zPW1CISc0DL9"
      },
      "source": [
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "from seqeval.metrics import (\n",
        "    classification_report,\n",
        "    f1_score,\n",
        "    precision_score,\n",
        "    recall_score)\n",
        "\n",
        "def results_test(preds, out_label_ids, labels):\n",
        "  preds = np.argmax(preds, axis=2)\n",
        "\n",
        "  label_map = {i: label for i, label in enumerate(labels)}\n",
        "\n",
        "  out_label_list = [[] for _ in range(out_label_ids.shape[0])]\n",
        "  preds_list = [[] for _ in range(out_label_ids.shape[0])]\n",
        "\n",
        "  for i in range(out_label_ids.shape[0]):\n",
        "      for j in range(out_label_ids.shape[1]):\n",
        "          if out_label_ids[i, j] != -100:\n",
        "              out_label_list[i].append(label_map[out_label_ids[i][j]])\n",
        "              preds_list[i].append(label_map[preds[i][j]])\n",
        "\n",
        "  results = {\n",
        "      \"precision\": precision_score(out_label_list, preds_list),\n",
        "      \"recall\": recall_score(out_label_list, preds_list),\n",
        "      \"f1\": f1_score(out_label_list, preds_list),\n",
        "  }\n",
        "  return results, classification_report(out_label_list, preds_list)"
      ],
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dVuQ73870FMF",
        "outputId": "30662245-28a4-4d5e-b70a-67e0bc8d98d6"
      },
      "source": [
        "labels = list(set(all_labels))\n",
        "val_result, class_report = results_test(preds_val, out_label_ids, labels)\n",
        "print(\"Overall results:\", val_result)\n",
        "print(class_report)"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Overall results: {'precision': 0.917298937784522, 'recall': 0.9166034874905231, 'f1': 0.9169510807736064}\n",
            "                         precision    recall  f1-score   support\n",
            "\n",
            "                enu.cnt       1.00      0.97      0.98       224\n",
            "      enu.discountprice       0.62      0.50      0.56        10\n",
            "       enu.itemsubtotal       0.00      0.00      0.00         6\n",
            "                 enu.nm       0.96      0.90      0.93       251\n",
            "                enu.num       0.85      1.00      0.92        11\n",
            "              enu.price       0.96      0.98      0.97       247\n",
            "            enu.sub_cnt       0.76      0.94      0.84        17\n",
            "             enu.sub_nm       0.52      0.91      0.66        32\n",
            "          enu.sub_price       0.81      0.85      0.83        20\n",
            "          enu.unitprice       0.96      0.96      0.96        68\n",
            "         otal.cashprice       0.94      0.92      0.93        71\n",
            "       otal.changeprice       0.93      0.97      0.95        59\n",
            "   otal.creditcardprice       0.88      0.82      0.85        17\n",
            "       otal.emoneyprice       0.33      0.50      0.40         2\n",
            "       otal.menuqty_cnt       0.80      0.83      0.81        29\n",
            "      otal.menutype_cnt       0.00      0.00      0.00         7\n",
            "         otal.total_etc       0.00      0.00      0.00         4\n",
            "       otal.total_price       0.97      0.92      0.94       101\n",
            "ub_total.discount_price       1.00      1.00      1.00         7\n",
            "           ub_total.etc       0.27      0.38      0.32         8\n",
            " ub_total.service_price       0.92      0.92      0.92        12\n",
            "ub_total.subtotal_price       0.87      0.96      0.91        69\n",
            "     ub_total.tax_price       0.89      0.83      0.86        47\n",
            "\n",
            "              micro avg       0.92      0.92      0.92      1319\n",
            "              macro avg       0.71      0.74      0.72      1319\n",
            "           weighted avg       0.92      0.92      0.91      1319\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6amDSVoxR7zU"
      },
      "source": [
        "The results I was getting were: \n",
        "\n",
        "`{'precision': 0.9307458143074582, 'recall': 0.9272175890826384, 'f1': 0.9289783516900872}`"
      ]
    }
  ]
}